{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Recurrent Neural Networks\n",
    "\n",
    "## Sequence Representation\n",
    "\n",
    "### Embedding Layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
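    "The Embedding layer encodes a word as a word vector $v$. Its input is the word's integer index $i$ (for example, 2 for “I” and 3 for “me”). With $N_{vocab}$ words in the vocabulary, it outputs a vector $v$ of length $n$: $$v=f_{\\theta}(i|N_{vocab}, n)$$ The parameters $\\theta$ form a trainable table of shape $[N_{vocab}, n]$, and the layer returns row $i$ of that table."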
   ]
  },
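  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, this lookup is the same as multiplying a one-hot encoding of $i$ by the $[N_{vocab}, n]$ table. The NumPy sketch below illustrates the equivalence. Its sizes mirror the next cell, and its random table only imitates a freshly initialized layer. The uniform range of $\\pm 0.05$ is an assumption based on the Keras default initializer, so these values are not the ones TensorFlow produces.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Sizes mirroring the next cell: N_vocab = 10 words, n = 4 dimensions\n",
    "N_vocab, n = 10, 4\n",
    "rng = np.random.default_rng(0)\n",
    "# Stand-in for the trainable table theta (Keras stores it as `embeddings`)\n",
    "table = rng.uniform(-0.05, 0.05, size=(N_vocab, n)).astype(np.float32)\n",
    "\n",
    "# Shuffled word indices, like tf.random.shuffle(tf.range(10))\n",
    "ids = rng.permutation(N_vocab)\n",
    "# Lookup: the vector for word i is row i of the table\n",
    "vectors = table[ids]\n",
    "\n",
    "# Equivalent view: one-hot encode the indices and multiply by the table\n",
    "one_hot = np.eye(N_vocab, dtype=np.float32)[ids]\n",
    "assert np.allclose(one_hot @ table, vectors)\n",
    "print(vectors.shape)  # (10, 4)\n",
    "```\n",
    "\n",
    "Indexing avoids building the one-hot matrix. In both views, the gradient with respect to the table is nonzero only in the rows that were looked up."
   ]
  },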
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=19, shape=(10, 4), dtype=float32, numpy=\n",
       "array([[ 0.02974591, -0.03443984, -0.02403122, -0.02153288],\n",
       "       [-0.03018601,  0.03954509, -0.02528677,  0.03129846],\n",
       "       [-0.01986039, -0.03339057,  0.03154424, -0.0137913 ],\n",
       "       [-0.0204801 ,  0.00376071,  0.00240863,  0.02820292],\n",
       "       [ 0.01496815,  0.04902376, -0.00579251, -0.01775237],\n",
       "       [ 0.00454669, -0.04509263, -0.00747275,  0.03134331],\n",
       "       [-0.00305223, -0.01019449, -0.02831918, -0.01832809],\n",
       "       [-0.01227068,  0.0409546 , -0.01069511, -0.04826144],\n",
       "       [-0.01523788, -0.02214501, -0.0169457 , -0.00288089],\n",
       "       [-0.04996878,  0.00628467, -0.04327526, -0.04943682]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "# Integer codes for 10 words\n",
    "x = tf.range(10)\n",
    "# Shuffle them\n",
    "x = tf.random.shuffle(x)\n",
    "# Embedding layer for a vocabulary of 10 words, each represented by a length-4 vector\n",
    "net = layers.Embedding(10, 4)\n",
    "# Look up the word vectors\n",
    "out = net(x)\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'embedding/embeddings:0' shape=(10, 4) dtype=float32, numpy=\n",
       "array([[-0.00305223, -0.01019449, -0.02831918, -0.01832809],\n",
       "       [-0.0204801 ,  0.00376071,  0.00240863,  0.02820292],\n",
       "       [-0.01227068,  0.0409546 , -0.01069511, -0.04826144],\n",
       "       [ 0.01496815,  0.04902376, -0.00579251, -0.01775237],\n",
       "       [-0.04996878,  0.00628467, -0.04327526, -0.04943682],\n",
       "       [-0.01986039, -0.03339057,  0.03154424, -0.0137913 ],\n",
       "       [ 0.00454669, -0.04509263, -0.00747275,  0.03134331],\n",
       "       [-0.03018601,  0.03954509, -0.02528677,  0.03129846],\n",
       "       [-0.01523788, -0.02214501, -0.0169457 , -0.00288089],\n",
       "       [ 0.02974591, -0.03443984, -0.02403122, -0.02153288]],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The trainable lookup table, shape [N_vocab, n]\n",
    "net.embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# net.embeddings is trainable (True), i.e. it is optimized by gradient descent\n",
    "net.embeddings.trainable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
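    "## Recurrent Neural Networks\n",
    "\n",
    "&emsp;&emsp;Folded along the time axis, the network repeatedly takes each feature vector $x_t$ of the sequence, updates its internal state vector $h_t$, and produces an output $o_t$. A network with this structure is called a Recurrent Neural Network (RNN).  \n",
    "&emsp;&emsp;Suppose the network $f_\\theta$ is parameterized by the weight tensors $W_{xh}$ and $W_{hh}$ and the bias $b$, and the state is updated as $$h_t=\\sigma(W_{xh}x_t+W_{hh}h_{t-1}+b)$$ where $\\sigma$ is an activation function (tanh by default in Keras's SimpleRNN). This network is called a basic recurrent neural network."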
   ]
  },
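  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of this update rule follows. It assumes $\\sigma=\\tanh$ and $o_t=h_t$, which is what a basic Keras SimpleRNNCell computes with its default activation. It uses the Keras row-vector layout, where a batch of inputs is multiplied on the left. In that layout the formula's $W_{xh}x_t$ becomes `x_t @ W_xh`, with `W_xh` of shape `[n, h]`. The function name `simple_rnn` and all sizes are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def simple_rnn(x, W_xh, W_hh, b, h0=None):\n",
    "    # x: [b, s, n]; row-vector layout: h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b)\n",
    "    batch, steps, _ = x.shape\n",
    "    h = np.zeros((batch, W_hh.shape[0]), dtype=x.dtype) if h0 is None else h0\n",
    "    outputs = []\n",
    "    for t in range(steps):\n",
    "        h = np.tanh(x[:, t, :] @ W_xh + h @ W_hh + b)  # sigma = tanh here\n",
    "        outputs.append(h)                              # o_t = h_t for this basic RNN\n",
    "    return np.stack(outputs, axis=1), h                # [b, s, h] and [b, h]\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n, units = 4, 3\n",
    "x = rng.normal(size=(2, 5, n))\n",
    "W_xh = rng.normal(size=(n, units)) * 0.5\n",
    "W_hh = rng.normal(size=(units, units)) * 0.5\n",
    "b = np.zeros(units)\n",
    "seq, h_last = simple_rnn(x, W_xh, W_hh, b)\n",
    "print(seq.shape, h_last.shape)  # (2, 5, 3) (2, 3)\n",
    "```\n",
    "\n",
    "Only the state $h$ is carried from one timestamp to the next. This is why the same weights can process sequences of any length."
   ]
  },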
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
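    "## Using RNN Layers\n",
    "\n",
    "In TensorFlow, `layers.SimpleRNNCell` computes a single step, $\\sigma(W_{xh}x_t+W_{hh}h_{t-1}+b)$. The loop over timestamps is either written by hand or delegated to `layers.SimpleRNN`."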
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SimpleRNNCell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=\n",
       " array([[ 0.4115522 ,  0.22107458, -0.88970715],\n",
       "        [-0.04243898, -0.02677888,  0.17191768],\n",
       "        [ 0.6828035 ,  0.90170157, -0.56374496],\n",
       "        [ 0.8897619 ,  0.6086905 ,  0.06193233]], dtype=float32)>,\n",
       " <tf.Variable 'recurrent_kernel:0' shape=(3, 3) dtype=float32, numpy=\n",
       " array([[-0.47960746,  0.8641869 , -0.15217681],\n",
       "        [ 0.56623876,  0.43728444,  0.69868165],\n",
       "        [ 0.670336  ,  0.2489245 , -0.699061  ]], dtype=float32)>,\n",
       " <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
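    "# Create an RNN cell whose state vector has length 3\n",
    "cell = layers.SimpleRNNCell(3)\n",
    "# Build it for input features of length n=4\n",
    "cell.build(input_shape=(None, 4))\n",
    "# Show the W_xh (kernel), W_hh (recurrent_kernel) and b (bias) tensors\n",
    "cell.trainable_variables"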
   ]
  },
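  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These shapes give the number of trainable parameters of a SimpleRNNCell with input length $n$ and state length $h$: $nh + h^2 + h$. The small helper below makes the arithmetic explicit. `simple_rnn_param_count` is a hypothetical name, not a Keras API.\n",
    "\n",
    "```python\n",
    "def simple_rnn_param_count(n, units):\n",
    "    # kernel W_xh: n x units, recurrent_kernel W_hh: units x units, bias b: units\n",
    "    return n * units + units * units + units\n",
    "\n",
    "print(simple_rnn_param_count(4, 3))     # 24: the cell built above\n",
    "print(simple_rnn_param_count(100, 64))  # 10560: a 64-unit cell on 100-dim inputs\n",
    "```\n",
    "\n",
    "The second value should match the parameter count that `summary()` reports for a `SimpleRNN(64)` layer applied to 100-dimensional inputs."
   ]
  },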
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 64) (4, 64)\n"
     ]
    }
   ],
   "source": [
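    "# Initial state vector, wrapped in a list for a uniform format\n",
    "h0 = [tf.zeros([4, 64])]\n",
    "# Input tensor: 4 sentences of 80 words, each word a 100-dim feature vector\n",
    "x = tf.random.normal([4, 80, 100])\n",
    "# The first word of every sentence, shape [4, 100]\n",
    "xt = x[:, 0, :]\n",
    "# Cell for input features n=100 and state length 64; the sequence length s=80 is handled by the caller's loop\n",
    "cell = layers.SimpleRNNCell(64)\n",
    "# Forward pass for a single timestamp\n",
    "out, h1 = cell(xt, h0)\n",
    "print(out.shape, h1[0].shape)"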
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "492815752 492815752\n"
     ]
    }
   ],
   "source": [
    "# out and h1[0] are the same object: for SimpleRNNCell the output o_t is the new state h_t\n",
    "print(id(out), id(h1[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
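    "# h holds the state-vector list, starting from h0\n",
    "h = h0\n",
    "# Unstack the input along the sequence dimension, giving xt: [b, n]\n",
    "for xt in tf.unstack(x, axis=1):\n",
    "    # Forward pass; out and h are overwritten at every timestamp\n",
    "    out, h = cell(xt, h)\n",
    "# The final output could aggregate the outputs of all timestamps; here it is just the last timestamp's output\n",
    "out = out"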
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-layer SimpleRNNCell Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = tf.random.normal([4, 80, 100])\n",
    "# Input at the first timestamp, x0\n",
    "xt = x[:, 0, :]\n",
    "# Build 2 cells, cell0 followed by cell1, both with state length 64\n",
    "cell0 = layers.SimpleRNNCell(64)\n",
    "cell1 = layers.SimpleRNNCell(64)\n",
    "# Initial state of cell0\n",
    "h0 = [tf.zeros([4, 64])]\n",
    "# Initial state of cell1\n",
    "h1 = [tf.zeros([4, 64])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
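    "The forward pass of the whole network loops over the time axis. At each timestamp, the input xt first passes through the first layer, giving out0. out0 then passes through the second layer, giving out1."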
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "for xt in tf.unstack(x, axis=1):\n",
    "    # xt is the input of the first cell; its output is out0\n",
    "    out0, h0 = cell0(xt, h0)\n",
    "    # The previous cell's output out0 is the input of this cell\n",
    "    out1, h1 = cell1(out0, h1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
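    "# Alternatively, run the layers one after another. Reset the states first,\n",
    "# because the previous cell left h0 and h1 at their final-timestamp values\n",
    "h0 = [tf.zeros([4, 64])]\n",
    "h1 = [tf.zeros([4, 64])]\n",
    "# Stores the first layer's outputs at all timestamps\n",
    "middle_sequences = []\n",
    "# Compute the first layer over all timestamps and save each output\n",
    "for xt in tf.unstack(x, axis=1):\n",
    "    out0, h0 = cell0(xt, h0)\n",
    "    middle_sequences.append(out0)\n",
    "# Then compute the second layer over all timestamps;\n",
    "# if it were not the last layer, its outputs would also need to be saved\n",
    "for xt in middle_sequences:\n",
    "    out1, h1 = cell1(xt, h1)"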
   ]
  },
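  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both loop orders perform the same computations and differ only in scheduling, so they produce the same final output. The NumPy sketch below checks this. It uses illustrative sizes and a hand-written `rnn_step` function (a hypothetical stand-in for `SimpleRNNCell`, with tanh activation assumed).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rnn_step(x_t, h, W_xh, W_hh, b):\n",
    "    # One SimpleRNNCell-style step: tanh activation, Keras row-vector layout\n",
    "    return np.tanh(x_t @ W_xh + h @ W_hh + b)\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "batch, steps, n, units = 4, 6, 5, 3\n",
    "x = rng.normal(size=(batch, steps, n))\n",
    "# Weights of layer 0 (input length n) and layer 1 (input = layer 0's output)\n",
    "p0 = (rng.normal(size=(n, units)), rng.normal(size=(units, units)), np.zeros(units))\n",
    "p1 = (rng.normal(size=(units, units)), rng.normal(size=(units, units)), np.zeros(units))\n",
    "\n",
    "# Order 1: at each timestamp, run layer 0 and then layer 1\n",
    "h0 = np.zeros((batch, units))\n",
    "h1 = np.zeros((batch, units))\n",
    "for t in range(steps):\n",
    "    h0 = rnn_step(x[:, t], h0, *p0)\n",
    "    h1 = rnn_step(h0, h1, *p1)\n",
    "interleaved = h1\n",
    "\n",
    "# Order 2: run layer 0 over the whole sequence, then layer 1\n",
    "h0 = np.zeros((batch, units))\n",
    "middle = []\n",
    "for t in range(steps):\n",
    "    h0 = rnn_step(x[:, t], h0, *p0)\n",
    "    middle.append(h0)\n",
    "h1 = np.zeros((batch, units))\n",
    "for out0 in middle:\n",
    "    h1 = rnn_step(out0, h1, *p1)\n",
    "layerwise = h1\n",
    "\n",
    "print(np.allclose(interleaved, layerwise))  # True\n",
    "```\n",
    "\n",
    "The layer-by-layer order needs memory for a whole intermediate sequence. In exchange, each layer's loop can run on its own, which is also how stacked `SimpleRNN` layers with `return_sequences=True` are organized."
   ]
  },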
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SimpleRNN Layer\n",
    "\n",
    "`layers.SimpleRNN` runs the loop over timestamps internally. Forward pass of a single-layer RNN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([4, 64])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
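    "# Create a SimpleRNN layer with state length 64\n",
    "layer = layers.SimpleRNN(64)\n",
    "x = tf.random.normal([4, 80, 100])\n",
    "# As with a convolutional layer, a single call gives the output;\n",
    "# by default only the last timestamp's output is returned: [b, h]\n",
    "out = layer(x)\n",
    "out.shape"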
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([4, 80, 64])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
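    "# Create the RNN layer so that it returns the outputs at all timestamps\n",
    "layer = layers.SimpleRNN(64, return_sequences=True)\n",
    "# Forward pass\n",
    "out = layer(x)\n",
    "# The per-timestamp outputs are combined automatically into [b, s, h]\n",
    "out.shape"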
   ]
  },
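  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of `return_sequences` can be sketched in NumPy. The per-timestamp outputs are stacked along a new time axis, and the default output is simply the last slice of that stack. The weights here are random stand-ins rather than the layer's own. Tanh activation is assumed, and the mean over timestamps is only one example of aggregating all outputs.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "batch, steps, n, units = 4, 80, 100, 64\n",
    "x = rng.normal(size=(batch, steps, n)).astype(np.float32)\n",
    "W_xh = (rng.normal(size=(n, units)) * 0.1).astype(np.float32)\n",
    "W_hh = (rng.normal(size=(units, units)) * 0.1).astype(np.float32)\n",
    "b = np.zeros(units, dtype=np.float32)\n",
    "\n",
    "h = np.zeros((batch, units), dtype=np.float32)\n",
    "outputs = []\n",
    "for t in range(steps):\n",
    "    h = np.tanh(x[:, t] @ W_xh + h @ W_hh + b)\n",
    "    outputs.append(h)\n",
    "\n",
    "all_steps = np.stack(outputs, axis=1)  # like return_sequences=True: [b, s, h]\n",
    "last_step = all_steps[:, -1]           # like the default return_sequences=False: [b, h]\n",
    "mean_pooled = all_steps.mean(axis=1)   # one alternative aggregation over timestamps\n",
    "print(all_steps.shape, last_step.shape, mean_pooled.shape)\n",
    "# (4, 80, 64) (4, 64) (4, 64)\n",
    "```"
   ]
  },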
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([4, 64])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
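    "from tensorflow import keras\n",
    "\n",
    "# Build a 2-layer RNN network\n",
    "net = keras.Sequential([\n",
    "    # Every layer except the last must return the outputs at all timestamps,\n",
    "    # which serve as the input sequence of the next layer\n",
    "    layers.SimpleRNN(64, return_sequences=True),\n",
    "    layers.SimpleRNN(64),\n",
    "])\n",
    "# Forward pass\n",
    "out = net(x)\n",
    "out.shape"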
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vanishing and Exploding Gradients"
   ]
  },
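  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why do powers of a matrix matter? Take the standard backpropagation-through-time argument, with the update rule $h_t=\\sigma(W_{xh}x_t+W_{hh}h_{t-1}+b)$ and column vectors. The gradient that flows from $h_t$ back to an earlier state $h_i$ is a chain of $t-i$ Jacobians, each containing $W_{hh}$:\n",
    "\n",
    "$$\\frac{\\partial h_t}{\\partial h_i}=\\frac{\\partial h_t}{\\partial h_{t-1}}\\cdots\\frac{\\partial h_{i+1}}{\\partial h_i},\\qquad \\frac{\\partial h_k}{\\partial h_{k-1}}=\\mathrm{diag}\\left(\\sigma'(W_{xh}x_k+W_{hh}h_{k-1}+b)\\right)W_{hh}$$\n",
    "\n",
    "Setting the activation derivatives aside, the product behaves like $W_{hh}^{t-i}$. Roughly speaking, the gradient tends to explode when the largest eigenvalue magnitude of $W_{hh}$ exceeds 1, and tends to vanish when it is below 1. The cells below illustrate this with an arbitrary matrix $W$ standing in for $W_{hh}$."
   ]
  },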
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=7857, shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
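    "# Create an arbitrary (symmetric) matrix\n",
    "W = tf.ones([2, 2])\n",
    "# Compute its eigenvalues: eigh returns (eigenvalues, eigenvectors), eigenvalues in ascending order\n",
    "eigenvalues = tf.linalg.eigh(W)[0]\n",
    "eigenvalues"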
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val = [W]\n",
    "# Repeated matrix multiplication: val[k] = W^(k+1)\n",
    "for i in range(10):\n",
    "    val.append(val[-1] @ W)\n",
    "# Compute the L2 (Frobenius) norm of each power\n",
    "norm = list(map(lambda x: tf.norm(x).numpy(), val))\n",
    "norm"
   ]
  },
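  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The doubling follows from the eigenvalues. For a symmetric matrix $W=Q\\,\\mathrm{diag}(\\lambda)\\,Q^T$ we have $W^k=Q\\,\\mathrm{diag}(\\lambda^k)\\,Q^T$, so $\\|W^k\\|_F=\\sqrt{\\sum_i \\lambda_i^{2k}}$, which is dominated by the largest $|\\lambda_i|$. The NumPy sketch below compares repeated multiplication with this eigenvalue formula for two matrices: the matrix above, and a copy scaled so its largest eigenvalue is 0.8. The helper names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def power_norms(W, max_power=11):\n",
    "    # Frobenius norm of W^1, ..., W^max_power (tf.norm's default on a matrix)\n",
    "    norms, P = [], np.eye(W.shape[0])\n",
    "    for _ in range(max_power):\n",
    "        P = P @ W\n",
    "        norms.append(np.linalg.norm(P))\n",
    "    return np.array(norms)\n",
    "\n",
    "def eigen_norms(W, max_power=11):\n",
    "    # For a symmetric W = Q diag(lam) Q^T, W^k = Q diag(lam^k) Q^T,\n",
    "    # so ||W^k||_F = sqrt(sum(lam_i^(2k))), dominated by the largest |lam_i|\n",
    "    lam = np.linalg.eigvalsh(W)\n",
    "    k = np.arange(1, max_power + 1)[:, None]\n",
    "    return np.sqrt((lam[None, :] ** (2 * k)).sum(axis=1))\n",
    "\n",
    "# Grows as 2^k for the first matrix, shrinks as 0.8^k for the second\n",
    "for W in (np.ones((2, 2)), 0.4 * np.ones((2, 2))):\n",
    "    print(np.round(power_norms(W), 4))\n",
    "    assert np.allclose(power_norms(W), eigen_norms(W))\n",
    "```"
   ]
  },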
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEGCAYAAACUzrmNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAa+ElEQVR4nO3df7Rd853/8edLJPEzRRMZkhCtrBJav24x1TH94osOTehXiVkdGTWTshSd0Sqd70zLDNXql+pM/QgiLCo12nFviV+N+jGLIj9IhGaEpFzJyJ0qiVYTSd7fP/Y+K8dxk3Nz79ln733O67HWWefcz93nnPdZ4bzu57P3fm9FBGZmZpuyRd4FmJlZ8TkszMysLoeFmZnV5bAwM7O6HBZmZlbXlnkXkJXhw4fH2LFj8y7DzKxU5syZ8z8RMaJ2vGXDYuzYscyePTvvMszMSkXSb3ob9zKUmZnV5bAwM7O6HBZmZlaXw8LMzOpyWJiZWV0OCzMzq8thYWZmdTkszMxaxK9+BZdeCitXNv61HRZmZi3i1lvhsstg8ODGv7bDwsysBURAVxcccwxsvXXjXz+zsJA0RtIvJb0oaaGk89LxnSQ9JOml9H7HqudcJGmxpEWSjqkaP0jSgvR3P5SkrOo2MyujOXPg9ddh4sRsXj/LmcVa4PyI2Bs4FDhb0njgQmBWRIwDZqU/k/5uErAPcCxwjaRB6WtdC0wBxqW3YzOs28ysdLq6YIst4Ljjsnn9zMIiIpZHxNz08SrgRWAUMBG4Jd3sFuCE9PFEYEZErI6IJcBi4GBJuwDDIuLJSC4YfmvVc8zMDOjshE9/GoYPz+b1m7LPQtJY4ADgKWBkRCyHJFCAndPNRgGvVT2tOx0blT6uHe/tfaZImi1pdk9PTyM/gplZYS1ZAvPnw4QJ2b1H5mEhaTvgp8BXI2JTB3T1th8iNjH+wcGIqRHREREdI0Z8oB27mVlL6upK7rPaXwEZh4WkwSRBcXtE/CwdfiNdWiK9X5GOdwNjqp4+GliWjo/uZdzMzEjCYvx42HPP7N4jy6OhBNwEvBgRV1b9qguYnD6eDHRWjU+SNFTSHiQ7sp9Ol6pWSTo0fc3Tqp5jZtbWfvc7ePTRbGcVkO2V8g4D/gpYIOnZdOybwOXAnZLOAF4FvgAQEQsl3Qm8QHIk1dkRsS593lnAdGBr4L70ZmbW9mbOhHXrst1fAaDkAKPW09HREb6sqpm1upNPhscfT86x2KIBa0WS5kRER+24z+A2Myup1avh/vvhc59rTFBsisPCzKykHnkEVq3Kfn8FOCzMzEqrsxO22QaOOCL793JYmJmVUNaNA2s5LMzMSmju3GwbB9ZyWJiZlVBnZ7aNA2s5LMzMSqizEw47LLvGgbUcFmZmJVNpHNisJShwWJiZlc7Pf57cOyzMzGyjOjuzbxxYy2FhZlYilcaBWfeCquWwMDMrkUrjwGYuQYHDwsysVLq64E/+BA4+uLnv67AwMyuJ1avhvvua0ziwlsPCzKwkKo0Dm72/AhwWZmalUWkceOSRzX9vh4WZWQk0u3FgLYeFmVkJNLtxYC2HhZlZCTS7cWAth4WZWQk0u3FgLYeFmVnB5dE4sJbDwsys4CqNA/M4ZLbCYWFmVnCdnbD33jBuXH41OCzMzAqs0jgwzyUocFiYmRVaXo0DazkszMwKrKsLRo5sfuPAWg4LM7OCyrNxYC2HhZlZQVUaB+a9BAUOCzOzwsqzcWAth4WZWQHl3TiwlsPCzKyAKo0D8zwRr5rDwsysgCqNA48/Pu9KEg4LM7MCyrtxYC2HhZlZwSxdmn/jwFoOCzOzgunqSu6Lsr8CHBZmZoVThMaBtRwWZmYFUpTGgbUcFmZmBXLffcVoHFjLYWFmViCdncVoHFgrs7CQNE3SCknPV419W9Lrkp5Nb39R9buLJC2WtEjSMVXjB0lakP7uh5KUVc1mZnkq
UuPAWlmWMx04tpfxqyJi//Q2E0DSeGASsE/6nGskDUq3vxaYAoxLb729pplZ6RWpcWCtzMIiIh4D3uzj5hOBGRGxOiKWAIuBgyXtAgyLiCcjIoBbgROyqdjMLF9dXcVpHFgrj4nOVyTNT5epdkzHRgGvVW3TnY6NSh/XjvdK0hRJsyXN7unpaXTdZmaZqTQOPProYjQOrNXssLgW+CiwP7Ac+H/peG/7IWIT472KiKkR0RERHSNGjBhorWZmTTN3LnR3F3MJCpocFhHxRkSsi4j1wA1AZX9/NzCmatPRwLJ0fHQv42ZmLaVojQNrNTUs0n0QFScClSOluoBJkoZK2oNkR/bTEbEcWCXp0PQoqNOAzmbWbGbWDF1dxWocWGvLrF5Y0h3AZ4DhkrqBbwGfkbQ/yVLSUuDLABGxUNKdwAvAWuDsiFiXvtRZJEdWbQ3cl97MzFrG0qXw3HNwxRV5V7JxmYVFRJzay/BNm9j+UuDSXsZnA/s2sDQzs0KpNA4s6v4K8BncZma5K2LjwFoOCzOzHBW1cWAth4WZWY4qjQOLdO2K3jgszMxyVGkceMgheVeyaQ4LM7OcFLlxYK2Cl2dm1rqK3DiwlsPCzCwnRW4cWMthYWaWg6I3DqzlsDAzy0HRGwfWcliYmeWg6I0DazkszMxyUPTGgbUcFmZmTVZpHFj0E/GqOSzMzJqsDI0DazkszMyarAyNA2s5LMzMmqgsjQNrOSzMzJqoLI0DazkszMyaqCyNA2s5LMzMmqRMjQNrlaxcM7PyevTR8jQOrOWwMDNrks7O8jQOrOWwMDNrgrI1DqzlsDAza4KyNQ6s5bAwM2uCrq5kp/Zxx+VdSf9s2ZeNJHUA/wDsnj5HQETEJzKszcysZXR2wqc+BSNG5F1J//QpLIDbga8DC4D12ZVjZtZ6Ko0Dr7gi70r6r69h0RMRXZlWYmbWosrYOLBWX8PiW5JuBGYBqyuDEfGzTKoyM2shXV3laxxYq69hcTqwFzCYDctQATgszMw24a23kpPxzj8/70oGpq9hsV9EfDzTSszMWtDMmbB2bbmXoKDvh87+StL4TCsxM2tBZW0cWKuvM4tPA5MlLSHZZ+FDZ83M6lizJmkceMop5WscWKuvYXFsplWYmbWgRx5JGgeW7doVvakbFpK2AO6NiH2bUI+ZWcuoNA486qi8Kxm4uhOjiFgPPCdptybUY2bWEsreOLBWX5ehdgEWSnoa+H1lMCJaYHJlZtZ48+YljQP/+Z/zrqQx+hoWF2dahZlZi+nsLHfjwFp9CouIeFTSSOCT6dDTEbEiu7LMzMqt7I0Da/XpYC5JJwNPA18ATgaeknRSloWZmZVVpXFg2U/Eq9bXZah/AD5ZmU1IGgH8Argrq8LMzMqqFRoH1urraSJb1Cw7/bbecyVNk7RC0vNVYztJekjSS+n9jlW/u0jSYkmLJB1TNX6QpAXp734oSX2s2cwsF52dsNde5W4cWKuvYXG/pAck/bWkvwbuBWbWec50Pngy34XArIgYR9LB9kKAtJXIJGCf9DnXSBqUPudaYAowLr35BEEzK6xHH4WHH4ZTT827ksbqU1hExNeBqcAngP2AqRHxjTrPeQx4s2Z4InBL+vgW4ISq8RkRsToilgCLgYMl7QIMi4gnIyKAW6ueY2ZWKGvWwFlnwdix8LWv5V1NY/V1nwUR8VPgpwN8v5ERsTx9veWSdk7HRwG/qtquOx17L31cO94rSVNIZiHstpvPITSz5rrySnjxRfj5z5Mzt1tJX4+G+ny6n+FtSSslrZK0soF19LYfIjYx3quImBoRHRHRMaJVjlczs1JYsgQuuQROPBGOPz7vahqvr/ssvgdMiIgPRcSwiNg+Iob14/3eSJeWSO8rO827gTFV240GlqXjo3sZNzMrjAg455zkJLyrr867mmz0NSzeiIgXG/B+XcDk9PFkoLNqfJKkoZL2INmR/XS6ZLVK0qHpUVCnVT3HzKwQ7r4b7r0XLr4Y
xoypv30Z9XWfxWxJPwHupo/X4JZ0B/AZYLikbuBbwOXAnZLOAF4lOcmPiFgo6U7gBWAtcHZErEtf6iySI6u2Bu5Lb2ZmhfDOO3DuufDxjyf3raqvYTEM+ANwdNXYJq/BHREbO3DsyI1sfylwaS/jswG3RzezQvr2t5OGgT/5CQwenHc12elrb6jTsy7EzKxs5s+HH/wA/uZvkj5QrWyzL/QnaW4WhZiZlcn69XDmmbDjjnD55XlXk70+n2dRxe02zKztTZsGTz4JN98MH/5w3tVkrz+XEL+34VWYmZVITw9ccAEcfjhMnlx/+1aw2WEREf83i0LMzMriggtg1Sq45hpol9am9TrHjpE0Q9Ljkr4paXDV7+7Ovjwzs2J57DGYPj3p/bTPPnlX0zz1ZhbTgEeAc0iuw/2opMrq3O4Z1mVmVjjVjQL/8R/zrqa56u3gHhER16WPz5H0ReAxSRPYRI8mM7NWdOWV8MILrdkosJ56YTFY0lYR8UeAiLhN0n8DDwDbZl6dmVlBLF3a2o0C66m3DHUjcEj1QET8gqRNx/O9PsPMrMW0Q6PAejY5s4iIqzYyPk+SD6E1s7bQ2Qn33APf/37rNgqspz/nWVT8fcOqMDMrqHfeSWYVrd4osJ7+nMFd0SZHF5tZO7v44vZoFFjPQGYWPhrKzFra/Plw1VXt0Siwnk3OLCStovdQEMn1JczMWtL69ck5Fe3SKLCeeju4t29WIWZmRTJtGjzxRPs0CqxnIMtQZmYtqacHvvGN9moUWI/DwsysxgUXwMqV7dUosB6HhZlZlXZtFFiPw8LMLFVpFLj77u3XKLCegZxnYWbWUtq5UWA9nlmYmbGhUeAJJ7Rno8B6HBZm1vbcKLA+L0OZWdurNAq84grYbbe8qykmzyzMrK1VNwo877y8qykuzyzMrK25UWDfeGZhZm3LjQL7zmFhZm3JjQI3j5ehzKwtuVHg5vHMwszajhsFbj6HhZm1HTcK3HwOCzNrK24U2D8OCzNrG24U2H/ewW1mbeOqq9wosL88szCztrB0aXICnhsF9o/DwsxanhsFDpyXocys5blR4MB5ZmFmLe2dd+Dcc90ocKA8szCzlnbxxfDaa3DHHW4UOBCeWZhZy1qwYEOjwMMOy7uacsslLCQtlbRA0rOSZqdjO0l6SNJL6f2OVdtfJGmxpEWSjsmjZjMrl/Xr4cwz3SiwUfKcWfyviNg/IjrSny8EZkXEOGBW+jOSxgOTgH2AY4FrJA3Ko2AzK4fKyXdPPJHs1HajwIEr0jLUROCW9PEtwAlV4zMiYnVELAEWAwfnUJ+ZlcCKFXDUUTB1atIDyo0CGyOvsAjgQUlzJE1Jx0ZGxHKA9H7ndHwU8FrVc7vTsQ+QNEXSbEmze3p6MirdzIpq3jzo6IBnnoHbb4fvfteNAhslr6OhDouIZZJ2Bh6S9OtNbNvbP3X0tmFETAWmAnR0dPS6jZm1phkz4EtfSpac/vM/4aCD8q6oteQys4iIZen9CuA/SJaV3pC0C0B6vyLdvBsYU/X00cCy5lVrZkW2bh1cdBGceioceCDMnu2gyELTw0LStpK2rzwGjgaeB7qAyuriZKAzfdwFTJI0VNIewDjg6eZWbWZF9PbbMGFCcrTT3/4tPPwwjByZd1WtKY9lqJHAfyhZSNwS+HFE3C/pGeBOSWcArwJfAIiIhZLuBF4A1gJnR8S6HOo2swJZtAgmToSXX04uYnTmmd4/kaWmh0VEvALs18v4b4EjN/KcS4FLMy7NzErivvuSZafBg2HWrOTyqJatIh06a2a2SRHwve/BccfBHnsk+yccFM3h3lBmVgp/+EPStuOOO+Dkk2HaNNh227yrah+eWZhZ4b36KvzZnyWHx152WXLvoGguzyzMrNAefxxOOgnefRe6unyVu7x4ZmFmhXX99XDEEfChD8FTTzko8uSwMLPCqTQCPPPMpM/T00/D3nvnXVV7c1iYWaFUGgFed13SCPCee2CHHfKu
yrzPwswKY9685ES7np6kEeBf/mXeFVmFZxZmVggzZiRXs4tIGgE6KIrFYWFmuXIjwHLwMpSZ5ebtt5MZxMyZSSPAf/s3GDIk76qsNw4LM8uFGwGWi8PCzJrOjQDLx/sszKxp3AiwvDyzMLOmcCPAcvPMwswy50aA5eeZhZllJgIeeAAmT3YjwLLzzMLMGu53v4Orr4bx4+Gzn3UjwFbgsDCzhohIAuH002HXXeGrX01C4uab4bnn3Aiw7LwMZWYDsmoV/PjHSeO/Z59N9kVMngxf/jIccEDe1VmjOCzMrF+eey653sRttyWB8YlPwLXXJmdkDxuWd3XWaA4LM+uzd9+Ff//3ZBbx5JOw1VZwyinJ2deHHOIzsFuZw8LM6lq0KJlFTJ+e7Lz+2MfgqqvgtNNgp53yrs6awWFhZr1aswbuvjuZRfzyl0lrjs9/PplF/PmfexbRbhwWZvY+S5bADTfATTclV60bOxa+853kKKeRI/OuzvLisDAz1q5N2oRfdx3cf38ya/jc55JZxNFHwxY+yL7tOSzM2tjrrycziBtugO7u5PyIf/onOOMMGDMm7+qsSBwWZm1m/Xr4xS+SWURXV3KlumOOgX/91+QM6y39rWC98H8WZm2ipyc5m/r66+GVV2D4cPja15Ir1H30o3lXZ0XnsDBrQatXw4IFyfUiKrfnn09mEYcfDv/yL8mRTUOH5l2plYXDwqzk3nsPFi58fzDMn5+MA3z4w/DJT8KECTBpUtLcz2xzOSzMSmTtWvj1r98fDM8+m8wkAHbYATo64Pzzk/uODthtN58TYQPnsDArqPXr4b/+6/3BMG9ecsU5gO23h4MOgnPO2RAMH/mIg8Gy4bAwK4AIePnl9wfD3LlJgz6AbbaBAw+EKVM2BMO4cT7/wZrHYWHWZBHwm9+8PxjmzIG33kp+v9VWsP/+SZvvSjDstRcMGpRv3dbeHBZmDRQBK1fCsmUbvy1aBL/9bbL94MGw337JjudKMIwfn4ybFYnDwqyPfv/7TYdA5VbZp1Bt2LDk7Ohdd4UTT9wQDPvu68NXrRwcFtb2/vhHWL68fgisXPnB5269NYwalYRAR8eGQKi+7bILbLdd8z+XWSM5LKzU3nsv2Qm8ubeVK+GNN5IQePPND77ukCEbvuz33TdpptdbEAwb5qOPrD2UJiwkHQtcDQwCboyIy3MuyfogIjk3YPXq5PoItffvvtu/L/vKrXJ+QT1bbpkcalp923PP5Gzm3kJgp50cAmbVShEWkgYBPwL+N9ANPCOpKyJeyLeygYlIjqVfv37D47Vrk5YMlVuzfn7vvd6/0Df2Jd/X+zVrks+2uaRk6ab2C37s2A+O9eU2dKi//M0GohRhARwMLI6IVwAkzQAmAg0PiwkT4KWX3v8F3tuX+ub8vrdt+/MF2ixS8uU6ZMim74cOTZZh6m23qfuttur9y32bbXwOgVmRlCUsRgGvVf3cDRxSu5GkKcAUgN12261fb7TnnskX2BZbJDdp04/r/X5ztpWS5ZJBg5Jb9eOsfx4yZMOX+KBB/ivczN6vLGHR21fXB/42j4ipwFSAjo6Ofv3tfuWV/XmWmVlrK8tEvxuovm7XaGBZTrWYmbWdsoTFM8A4SXtIGgJMArpyrsnMrG2UYhkqItZK+grwAMmhs9MiYmHOZZmZtY1ShAVARMwEZuZdh5lZOyrLMpSZmeXIYWFmZnU5LMzMrC6HhZmZ1aUoct+JAZDUA/ymn08fDvxPA8spA3/m9tBun7ndPi8M/DPvHhEjagdbNiwGQtLsiOjIu45m8mduD+32mdvt80J2n9nLUGZmVpfDwszM6nJY9G5q3gXkwJ+5PbTbZ263zwsZfWbvszAzs7o8szAzs7ocFmZmVpfDooqkYyUtkrRY0oV515M1SWMk/VLSi5IWSjov75qaRdIgSfMk3ZN3Lc0gaQdJd0n6dfrv/ad515Q1SX+X/nf9vKQ7JG2Vd02NJmmapBWSnq8a20nSQ5Je
Su93bMR7OSxSkgYBPwI+C4wHTpU0Pt+qMrcWOD8i9gYOBc5ug89ccR7wYt5FNNHVwP0RsRewHy3+2SWNAs4FOiJiX5JLG0zKt6pMTAeOrRm7EJgVEeOAWenPA+aw2OBgYHFEvBIRa4AZwMSca8pURCyPiLnp41UkXyCj8q0qe5JGA8cBN+ZdSzNIGgYcDtwEEBFrIuKtfKtqii2BrSVtCWxDC15dMyIeA96sGZ4I3JI+vgU4oRHv5bDYYBTwWtXP3bTBF2eFpLHAAcBT+VbSFD8ALgDW511Ik3wE6AFuTpfebpS0bd5FZSkiXge+D7wKLAfejogH862qaUZGxHJI/iAEdm7EizosNlAvY21xXLGk7YCfAl+NiJV515MlSccDKyJiTt61NNGWwIHAtRFxAPB7GrQ0UVTpOv1EYA9gV2BbSV/Mt6pyc1hs0A2Mqfp5NC04ba0laTBJUNweET/Lu54mOAyYIGkpyVLjEZJuy7ekzHUD3RFRmTXeRRIerewoYElE9ETEe8DPgE/lXFOzvCFpF4D0fkUjXtRhscEzwDhJe0gaQrIzrCvnmjIlSSTr2C9GxJV519MMEXFRRIyOiLEk/8YPR0RL/8UZEf8NvCbpY+nQkcALOZbUDK8Ch0raJv3v/EhafKd+lS5gcvp4MtDZiBctzTW4sxYRayV9BXiA5MiJaRGxMOeysnYY8FfAAknPpmPfTK93bq3lHOD29A+hV4DTc64nUxHxlKS7gLkkR/3NowVbf0i6A/gMMFxSN/At4HLgTklnkITmFxryXm73YWZm9XgZyszM6nJYmJlZXQ4LMzOry2FhZmZ1OSzMzKwuh4VZA0k6oboZo6RLJB2VZ01mjeBDZ80aSNJ04J6IuCvvWswayTMLsz6QNDa9DsQN6TUSHpS0dc02nwImAFdIelbSRyVNl3RS+vulki6T9KSk2ZIOlPSApJclnVn1Ol+X9Iyk+ZIuTse2lXSvpOfS6zOc0szPb+awMOu7ccCPImIf4C3g/1T/MiKeIGm18PWI2D8iXu7lNV6LiD8FHie5FsFJJNcSuQRA0tHp+xwM7A8cJOlwkmsWLIuI/dLrM9yfwecz2yi3+zDruyURUWmLMgcY24/XqPQbWwBsl15HZJWkP0raATg6vc1Lt9uOJDweB74v6bsky1yP9/MzmPWLw8Ks71ZXPV4HbL2xDfvwGutrXm89yf+PAr4TEdfXPlHSQcBfAN+R9GBEXNKP9zfrFy9DmTXWKmD7ATz/AeBL6TVGkDRK0s6SdgX+EBG3kVzUp9VbjFvBeGZh1lgzgBsknUuyP2KzRMSDkvYGnkw6a/MO8EVgT5Id5+uB94CzGleyWX0+dNbMzOryMpSZmdXlsDAzs7ocFmZmVpfDwszM6nJYmJlZXQ4LMzOry2FhZmZ1/X/S4fkXTf8/wAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "x = range(11)\n",
    "\n",
    "plt.plot(x, norm, color='blue')\n",
    "plt.xlabel('n times')\n",
    "plt.ylabel('L2-norm')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now consider the case where the largest eigenvalue is less than 1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=7948, shape=(2,), dtype=float32, numpy=array([0. , 0.8], dtype=float32)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create another matrix, scaled so that its largest eigenvalue is 0.8\n",
    "W = tf.ones([2, 2]) * 0.4\n",
    "# Compute the eigenvalues\n",
    "eigenvalues = tf.linalg.eigh(W)[0]\n",
    "eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXgV9d3+8fcnGySQBEISlhBIWAQCymIEQVBcqFjXqqBo664P7tXWVrta2+dnF21rK9aFWmxdccel4K4oFgn7JhoBSQBJILKvIZ/fHzn6pBggQCaTk3O/rutc5JyZM7nPBeTOzHxnvubuiIhI7IoLO4CIiIRLRSAiEuNUBCIiMU5FICIS41QEIiIxLiHsAAcqMzPT8/Lywo4hIhJVZs6cudbds2pbFnVFkJeXR1FRUdgxRESiipl9vrdlOjQkIhLjVAQiIjFORSAiEuNUBCIiMU5FICIS4wItAjMbaWZLzKzYzG6tZXm6mb1kZnPNbKGZXRpkHhER+abAisDM4oFxwClAATDGzAr2WO1aYJG79wWGA3ebWVJQmURE5JuC3CMYCBS7+1J33wk8CZy5xzoOpJqZAS2BCqAyiDCrN2zj9kkL2bW7KojNi4hErSCLIAcoqfG8NPJaTfcCvYBVwHzgRnf/xk9qM7vKzIrMrKi8vPygwswt2cCEacv561vFB/V+EZGmKsgisFpe23MWnJOBOUAHoB9wr5mlfeNN7g+6e6G7F2Zl1XqF9H6N7NOOs/vnMO7tYuaUrD+obYiINEVBFkEpkFvjeUeqf/Ov6VLgOa9WDCwDegYV6Jdn9CY7tRk3T5zDtp27g/o2IiJRJcgimAF0N7P8yAng84FJe6yzAjgRwMzaAj2ApUEFSk9O5A/n9mVp+RZ+N/njoL6NiEhUCawI3L0SuA6YAiwGJrr7QjMba2ZjI6v9GhhiZvOBN4Efu/vaoDIBDO2eySVD8pgwbTkfFAf6rUREooJF2+T1hYWFfqh3H922czen/mUq23btZvL3jyU9ObGe0omINE5mNtPdC2tbFpNXFicnxfPH8/pRtmkHv3ppYdhxRERCFZNFANAvtxXXDu/Kc7NWMnnB6rDjiIiEJmaLAOC6E7rTJyeNnzy/gPJNO8KOIyISipgugqSEOP44uh+bd1Ry23PzibbzJSIi9SGmiwDgsLap/OjkHryxeA3PzCwNO46ISIOL+SIAuOyYfAblZ/CrlxZR+uXWsOOIiDQoFQEQF2fcNaov7s4Pn55LVZUOEYlI7FARRORmpPDL03vzn6UV/GPa8rDjiIg0GBVBDaMKO3JSr2x+N/ljPl2zKew4IiINQkVQg5lx59lH0LJZAjdPnKu5C0QkJqgI9pCV2oz/PasP81du4F7NXSAiMUBFUItTDm/P2f1zuPftYuZq7gIRaeJUBHvx1dwFN02cw/ZdmrtARJouFcFeaO4CEYkVKoJ9GNo9k4sHd+YfHyxnmuYuEJEmSkWwH7ee0osumS245Zl5bNy+K+w4IiL1TkWwH8lJ8dw9ui+rN2zjjpcWhR1HRKTeqQjqoH+n1lx7fDeemVnKlIVfhB1HRKReqQjq6PoTutO7Qxo/eW4+azdr7gIRaToCLQIzG2lmS8ys2MxurWX5LWY2J/JYYGa7zSwjyEwHKykhjj+d149NmrtARJqYwIrAzOKBccApQAEwxswKaq7j7n9w937u3g+4DXjX3SuCynSoDmubyi3f6sHri9bw7KyVYccREakXQe4RDASK3X2pu+8EngTO3Mf6Y4AnAsxTLy4bms/A/Ax+NWmh5i4QkSYhyCLIAUpqPC+NvPYNZpYCjASe3cvyq8ysyMyKysvL6z3ogYiPM+4e1Zcqd255ep7mLhCRqBdkEVgtr+3tp+bpwAd7Oyzk7g+6e6G7F2ZlZdVbwIOVm5HCL04v4MOl65iguQtEJMoFWQSlQG6N5x2BVXtZ93yi4LBQ
TaMLczmxZ/XcBcVlmrtARKJXkEUwA+huZvlmlkT1D/tJe65kZunAccCLAWapd2bGneccTkpSvOYuEJGoFlgRuHslcB0wBVgMTHT3hWY21szG1lj1O8Br7r4lqCxByU5tzv/7zuHMK93AuLc1d4GIRCeLtvHwhYWFXlRUFHaM/3LTU3N4ae4qnrtmCEd0bBV2HBGRbzCzme5eWNsyXVlcD24/ozeZLZtx88S5mrtARKKOiqAepCcn8odRR1Bctpk/TFkSdhwRkQOiIqgnw7pncfHgzvz9/WVM+0xzF4hI9FAR1KNbT+lFfmYLbnl6Hps0d4GIRAkVQT3S3AUiEo1UBPVsQKfWXDO8G0/PLOU1zV0gIlFARRCAG07sTkH7NG7T3AUiEgVUBAH4eu6C7ZX89HnNXSAijZuKICA92qXyw5MPY8rCNTynuQtEpBFTEQTo8qFdGJiXwe2TFrJy/baw44iI1EpFEKD4OOPu0dVzF1z72CxddSwijZKKIGC5GSncPbofc0vXc/PEOZrIRkQaHRVBAxjZpx0/OaUXr87/gt/rFhQi0sgkhB0gVlwxLJ/l67Zw/7uf0blNCmMGdgo7kogIoCJoMGbGr87oTemX2/jZCwvo2DqZYd3Dn3ZTRESHhhpQQnwc917Qn+7ZLbnm0Vks+UJTXIpI+FQEDSy1eSIPX3IUyUnxXDZhBmWbtocdSURinIogBB1aJfP3i4+iYstOrnykiG07NaxURMITaBGY2UgzW2JmxWZ2617WGW5mc8xsoZm9G2SexuTwjun8ZUx/5q3cwPefmq1hpSISmsCKwMzigXHAKUABMMbMCvZYpxVwH3CGu/cGRgWVpzEaUdCWn59awJSFa/jt5I/DjiMiMSrIUUMDgWJ3XwpgZk8CZwI1b9R/AfCcu68AcPeyAPM0Spcek8fn67bw4HtL6ZSRwneP7hx2JBGJMUEeGsoBSmo8L428VtNhQGsze8fMZprZRQHmaZTMjJ+fVsAJPbP55aSFvLMk5rpQREIWZBFYLa/teSA8ATgSOBU4Gfi5mR32jQ2ZXWVmRWZWVF5eXv9JQ5YQH8dfxvTnsLapXPf4bBav3hh2JBGJIUEWQSmQW+N5R2BVLetMdvct7r4WeA/ou+eG3P1Bdy9098KsrKZ5EVbLZgk8fEkhLZpVDytds1HDSkWkYQRZBDOA7maWb2ZJwPnApD3WeREYZmYJZpYCDAIWB5ipUWufXj2sdMO2XVz+yAy27qwMO5KIxIDAisDdK4HrgClU/3Cf6O4LzWysmY2NrLMYmAzMAz4Cxrv7gqAyRYM+Oen8dUx/Fq3ayA1PzGG3hpWKSMAs2qZRLCws9KKiorBjBO6Racv55aSFXHZMPr84vWD/bxAR2Qczm+nuhbUt003nGqmLh+SxfN0WHv5gGXmZKVw0OC/sSCLSRKkIGrGfnVpAScVWbp+0kNzWKRzfMzvsSCLSBOleQ41YfJxxz/n9KeiQxnWPz2Lhqg1hRxKRJkhF0Mi1aJbA3y8+irTkRC6fUMQXGzSsVETql4ogCrRNa87DlxzFpu27uGzCDLbs0LBSEak/KoIo0at9GuMuHMCSNZu4/onZGlYqIvVGRRBFhvfI5vYzevPWx2X8+uVF+3+DiEgdaNRQlPne0Z1ZsW4LD01dRuc2KVx6TH7YkUQkyqkIotBtp/RiRcVW7nh5ER1bpzCioG3YkUQkiunQUBSKizP+fF5/Ds9J54YnZrNgpYaVisjBUxFEqeSkeMZfXEhGiyQumzCDVeu3hR1JRKKUiiCKZadWDyvdtnM3l02YwWYNKxWRg6AiiHI92qUy7sIBfFq2mesen0Xl7qqwI4lIlFERNAHHHpbFb87qwztLyrn9pYVE2x1lRSRcGjXURIwZ2Inl67bwwLtLyWvTgiuGdQk7kohECRVBE/Ljk3uyYt1W/vfVxeRmpHBy73ZhRxKRKKBDQ01IXJzxp/P60bdjK258cjbzSteHHUlEooCKoIlpnhjP
QxcVktmyGZc/UkTpl1vDjiQijZyKoAnKSm3GPy45iu27qoeVrt28I+xIItKIBVoEZjbSzJaYWbGZ3VrL8uFmtsHM5kQevwgyTyzp3jaVB753JCsqtjL6/g9ZqQvORGQvAisCM4sHxgGnAAXAGDOrbRb2qe7eL/K4I6g8sWhI10wevXwQ5Zt3cO7fplFctjnsSCLSCNWpCMys0MyeN7NZZjbPzOab2bz9vG0gUOzuS919J/AkcOahBpYDU5iXwVNXDWbX7ipGP/Ah80t1XyIR+W913SN4DPgHcA5wOnBa5M99yQFKajwvjby2p8FmNtfM/m1mvWvbkJldZWZFZlZUXl5ex8jylYIOaTw9dgjJifGMeeg/TF+6LuxIItKI1LUIyt19krsvc/fPv3rs5z1Wy2t7XvI6C+js7n2BvwIv1LYhd3/Q3QvdvTArK6uOkaWm/MwWPHP1YNqlN+eihz/irY/XhB1JRBqJuhbBL81svJmNMbOzv3rs5z2lQG6N5x2BVTVXcPeN7r458vWrQKKZZdY1vByY9unJTPyfwfRol8pV/5zJi3NWhh1JRBqBuhbBpUA/YCTVh4S+Ojy0LzOA7maWb2ZJwPnApJormFk7M7PI1wMjeXTcIkAZLZJ47IpBFOa15vtPzeFfHy4PO5KIhKyut5jo6+6HH8iG3b3SzK4DpgDxwMPuvtDMxkaW3w+cC1xtZpXANuB81x3TApfaPJEJlw7kusdn8/MXF7J+6y6uO6EbkU4WkRhjdfm5a2YPAX9y99BnTC8sLPSioqKwYzQJlbur+NEz83hu9kquGJrPT0/tpTIQaaLMbKa7F9a2rK57BEOBi81sGbCD6hPB7u5H1FNGCUFCfBx3jepLWnIi499fxoZtu7jz7MNJiNcF5yKxpK5FMDLQFBKauDjjl6cXkJ6cyD1vfsqm7ZXcM6YfzRLiw44mIg1kv7/6mVkc8ErNYaN1HD4qUcLMuGnEYfzitAImL/yCyycUsUXTXorEjP0WgbtXAXPNrFMD5JEQXTY0n7tG9eXDpeu4cPx01m/dGXYkEWkAdT0Y3B5YaGZvmtmkrx5BBpNwnHtkR+67cACLVm3kvAf+Q9nG7WFHEpGA1XXU0HG1ve7u79Z7ov3QqKGGMa14LVf+s4g2LZvx6OWD6NQmJexIInII9jVqqE57BJEf+B8DqZHH4jBKQBrOkG6ZPHbl0Wzcvotz75/Gki82hR1JRAJS17uPjgY+AkYBo4HpZnZukMEkfP1yW/H0/wzGDEY/8CGzVnwZdiQRCUBdzxH8FDjK3S9294uovsX0z4OLJY1F97apPDN2CK1SEvnu+Om8/+nasCOJSD2raxHEuXtZjefrDuC9EuVyM1J4euxgOmWkcNmEGUxesDrsSCJSj+r6w3yymU0xs0vM7BLgFeDV4GJJY5Od2pynrhpMn5w0rnlsFhNnlOz/TSISFep6svgW4EHgCKAv8KC7/zjIYNL4pKck8ugVgzimWyY/enYe46cuDTuSiNSDut5iAnd/Fng2wCwSBVKSEhh/cSE3PzWX37yymA3bdnHziMN0szqRKFanIohMQvM7IJvqG859ddO5tACzSSPVLCGev4zpT2rzBP76VjEbtu3i9tN7ExenMhCJRnXdI/g9cLq7Lw4yjESP+DjjzrMPJz05kQfeW8qGbbu4a1RfEnXnUpGoU9ciWKMSkD2ZGbd9uxfpKYn8fvISNm2v5L4LB9A8UXcuFYkmdS2CIjN7iurJ5Xd89aK7PxdIKokq1wzvRnpyIj97YQEXPfwR4y8uJK15YtixRKSO6rofnwZsBb5F3ecslhhy4aDO3HN+f2Z9/iUXPPQf1m3esf83iUijUKc9Ane/NOggEv3O6NuB1GYJjH10JqMe+JBHLh1IboZuVifS2B3wmT0zm3UA6440syVmVmxmt+5jvaPMbLfuXxT9ju+ZzaNXDKJ80w5O++v7vLbwi7Ajich+HMwQjzqNETSzeGAccApQAIwxs4K9
rPc7YMpBZJFG6Ki8DF66bii5Gclc9a+Z3PHSInZWVoUdS0T24mCK4JU6rjcQKHb3pe6+E3gSOLOW9a6n+kK1slqWSZTKy2zBs1cP4ZIheTz8wTLOvX8aK9ZtDTuWiNTigIvA3X9Wx1VzgJo3pCmNvPY1M8sBvgPcv68NmdlVZlZkZkXl5eUHEldC1CwhntvP6M393x3AsrVbOPUvU/n3fN2wTqSx2WcRmFmumT1pZlPN7Cdmllhj2Qv72XZth5D2nA7tz8CP3X33vjbk7g+6e6G7F2ZlZe3n20pjM7JPe169YRhdslty9WOz+MWLC9i+a59/5SLSgPa3R/Aw8A7Vh2/aA++aWZvIss77eW8pkFvjeUdg1R7rFAJPmtly4FzgPjM7a/+xJdrkZqTw9P8M5oqh+fzzw88552/TWLZ2S9ixRIT9F0GWu9/v7nPc/XrgPuA9M+vKN3+739MMoLuZ5ZtZEnA+8F8T3rt7vrvnuXse8Axwjbvvb09DolRSQhw/O62A8RcVUvrlNk77y1Qmzd3zdwMRaWj7K4JEM2v+1RN3fxS4keoRPu339UZ3rwSui6y7GJjo7gvNbKyZjT202BLNTipoy6s3DqNn+zRueGI2tz03X4eKREJk7nv/xd7MbgJm7TlRvZn1B37v7iMCzvcNhYWFXlRU1NDfVgKwa3cVd7/2Cfe/+xk926Vy7wUD6JbdMuxYIk2Smc1098Lalu1zj8Dd/7RnCURen03dh5GK1CoxPo5bT+nJhEuPomzTDs64932em1UadiyRmHMo9wy+ud5SSEwb3iObV28YRp+cdG6eOJcfPj2XrTsrw44lEjMOpQg0C4nUm3bpzXn8ikFcf0I3np1Vypn3fsAnazaFHUskJhxKEexv1JDIAUmIj+MH3+rBvy4bxJdbd3HGve8zcUYJ+zqPJSKHbn8XlG0ys421PDYBHRooo8SYod0zefXGoQzo1JofPTuPm56aw+YdOlQkEpT9nSxOdfe0Wh6p7l7nie9FDlR2anP+dfkgbjrpMCbNXcUZf32fRas2hh1LpEnSBLPSaMXHGTee1J3HrjiazTsqOeu+D3hs+uc6VCRSz1QE0ugN7tqGV28cxqD8DH76/AKue2I2m7bvCjuWSJOhIpCokNmyGY9cOpAfjezB5AVfcNpf32d+6YawY4k0CSoCiRpxccY1w7vx5FVHs7OyinP+No0JHyzToSKRQ6QikKhzVF4Gr9wwjKHdM7n9pUVc/egsNmzToSKRg6UikKiU0SKJ8RcV8tNv9+KNxWs49S9TmVOyPuxYIlFJRSBRKy7OuPLYLkwcOxh3OPdv0xg/dakOFYkcIBWBRL0BnVrz6g3DOKFnNr95ZTGXTphBSYXmRxapKxWBNAnpKYk88L0juf30Aj5aVsFJf3yXP7/xieY5EKkDFYE0GWbGJcfk8+YPjmNEQVv+/ManjPjTu7y+aI0OF4nsg4pAmpz26cnce8EAHr9iEM0T4rnyn0VcNmEGyzVHskitVATSZA3plsmrNw7jZ6f2YsbyL/nWn97jrilL2LZTh4tEagq0CMxspJktMbNiM7u1luVnmtk8M5tjZkVmNjTIPBJ7EuPjuGJYF976wXGcekR77n27mJP++C7/nr9ah4tEIvY5Z/EhbdgsHvgEGAGUAjOAMe6+qMY6LYEt7u5mdgTVE9z33Nd2NWexHIqPllXwixcX8PEXmxjWPZPbz+hN1yzNkyxN30HPWXyIBgLF7r7U3XcCTwJn1lzB3Tf7/zVRCzTZjQRsYH4GL18/lNtPL2BOyXpG/vk9fvvvj9mi+Q4khgVZBDlASY3npZHX/ouZfcfMPgZeAS4LMI8IUD0T2iXH5PPWD4ZzVr8c7n/3M068+11emrtKh4skJgVZBLXNafyN/2Xu/nzkcNBZwK9r3ZDZVZFzCEXl5eX1HFNiVVZqM/4wqi/PXj2EzNQkrn9iNhc8NF1zJUvMCbIISoHcGs87Aqv2trK7vwd0NbPMWpY9
6O6F7l6YlZVV/0klph3ZuTUvXjuU35zVh0WrN3LKPVP59cuLNOeBxIwgi2AG0N3M8s0sCTgfmFRzBTPrZmYW+XoAkASsCzCTSK3i44zvHt2Zt384nNGFuTz8wTJOuPtdnp9dqsNF0uQFVgTuXglcB0wBFlM9ImihmY01s7GR1c4BFpjZHGAccJ7rf52EKKNFEneefTgvXHMMHVolc9NTcxn9wIeaL1matMCGjwZFw0eloVRVOU/PLOF3k5ewfutOLhqcx00jDiM9OTHsaCIHLKzhoyJRLS7OOO+oTrz1g+P47tGd+eeHyznhrneYWFRCVVV0/QIlsi8qApH9aJWSxB1n9mHSdUPJy2zBj56Zxzn3T9OcydJkqAhE6qhPTjrPjB3M3aP6UlKxjTPGvc9Pn5/Pl1t2hh1N5JCoCEQOgJlxzpEdeeuHx3HpkHyenFHC8Xe/w+PTV7Bbh4skSqkIRA5CWvNEfnF6Aa/cMJTD2qbyk+fn8537PuC9T8o13FSijopA5BD0bJfGU1cdzT3n96N80w4uevgjzrj3AyYvWK0TyhI1NHxUpJ7sqNzNC7NX8rd3PmP5uq10y27JNcO7cnrfDiTG63cuCde+ho+qCETq2e4q59X5qxn3djEff7GJjq2T+Z/jujLqyI40T4wPO57EKBWBSAjcnbeXlHHvW8XMWrGezJbNuHJYPhce3ZmWzRLCjicxRkUgEiJ3Z/qyCsa9XczUT9eSnpzIxUPyuHRIHq1bJIUdT2KEikCkkZhXup5xbxczZeEaUpLiuWBgJ648tgtt05qHHU2aOBWBSCPz6ZpN/O2dz3hx7iriI9cmjD2uC53btAg7mjRRKgKRRqqkYisPvPcZE4tKqdxdxel9O3DN8G70aJcadjRpYlQEIo1c2cbt/P39ZTz6n8/ZsnM3Iwracs3wrvTv1DrsaNJEqAhEosT6rTt5ZNrn/GPaMtZv3cUx3dpw7fBuDO7ahsgcTiIHRUUgEmW27KjkiY9W8OB7SynbtIN+ua249vhunNgzm7g4FYIcOBWBSJTavms3z81ayf3vfsaKiq30aJvKNcd35dTD25Ogq5XlAKgIRKJc5e4qXp63mvveKeaTNZvplJHC2OO6cs6ROTRL0NXKsn8qApEmoqrKeWPxGsa9Xczc0g20TWvGlcO6MGZgJ1roamXZh9CKwMxGAvcA8cB4d//tHssvBH4ceboZuNrd5+5rmyoCkeqrlad9to5xbxcz7bN1tGyWwOl923PukbkM6NRKJ5blG/ZVBIH9CmFm8cA4YARQCswws0nuvqjGasuA49z9SzM7BXgQGBRUJpGmwsw4plsmx3TLZPaKL3ls+gpenLOKJz4qoWtWC0YX5vKdATlkp+qKZdm/wPYIzGwwcLu7nxx5fhuAu9+5l/VbAwvcPWdf29UegUjtNu+o5NV5q5lYVELR518SH2cc3yOLUYW5nNAzW7fCjnGh7BEAOUBJjeel7Pu3/cuBf9e2wMyuAq4C6NSpU33lE2lSWjZLYPRRuYw+KpfPyjfzzMxSnp1ZyhuLy2jTIonv9M9hVGGurlqWbwhyj2AUcLK7XxF5/j1goLtfX8u6xwP3AUPdfd2+tqs9ApG6q9xdxdRP1zKxqIQ3Fq9h126nb8d0RhXmcnrfDqQnJ4YdURpIWHsEpUBujecdgVV7rmRmRwDjgVP2VwIicmAS4uM4vmc2x/fMpmLLTl6YvZKJRSX87IUF/PrlRYzs047RhbkM7tJGF6rFsCD3CBKAT4ATgZXADOACd19YY51OwFvARe4+rS7b1R6ByKFxdxau2sjEohJemL2SjdsryWmVzLlHduTcIzuSm5ESdkQJQJjDR78N/Jnq4aMPu/v/mtlYAHe/38zGA+cAn0feUrm3oF9REYjUn+27dvP6ojVMLCrh/eK1uMOQrm0YXZjLyD7tNLVmE6ILykRkv1au38ZzM0t5emYpKyq2kto8gdP7dmB0YS59O6br2oQopyIQkTqrqqqe
WvPpmSW8On8123dVcVjblow6Mpez+ueQldos7IhyEFQEInJQNm3fxcvzVvN0UQmzVqwnIc44oWc2owpzGd4jS9cmRBEVgYgcsuKyTTxdVMqzs1aydvMOMls24+wBOYzs045+HVtp1FEjpyIQkXqza3cV7y4p5+mZJby5uIzKKiezZTNO7JnNSQVtGdotk+QknWRubFQEIhKIDVt38c4nZbyxuIx3Pi5j045KmiXEMax7JiMK2nJCz7Y6p9BIqAhEJHA7K6uYsbyC1xet4fVFa1i5fhtm0C+3FSf1asuIgrZ0z26p0UchURGISINydz7+YhNvLFrDG4vXMLd0AwCdMlI4qVdbTirI5qi8DJ1sbkAqAhEJ1ZqN23lzcRlvLF7D+8Vr2VlZRXpyIsf3yOKkgrYce1gWac1136MgqQhEpNHYsqOSqZ+u5Y3Fa3jr4zIqtuwkMd44uksbTurVlhN7ZdOxtW5zUd9UBCLSKO2ucmav+JLXF1efV1havgWAXu3TGNGrehTS4Tm6qrk+qAhEJCp8Vr6ZNxev4Y1FZRR9XkGVQ9u0ZpHzCm0Z3KWN7n90kFQEIhJ1Krbs5O2Py3h90Rre+7ScrTt3k5IUz7HdvzqvkKmpOA+AikBEotr2Xbv5cOm6r0chrdm4A4AumS0Y1CWDgfkZDMpvQ4dWySEnbbxUBCLSZLg7C1ZuZNpna5m+rIIZyyvYtL0SgI6tkxmU34ZB+RkM6pJBp4wUnV+IUBGISJO1u8pZvHojHy2rqH4sr6Biy04A2qU1Z2B+9R7D0V0y6JoVuxe0qQhEJGZUVTmflW/mP5FimL50HWWbqg8ltWmR9HUxDMpvQ892qTFzs7yw5iwWEWlwcXFG97apdG+byveO7oy78/m6rUxfto7pyyqYvrSCfy/4AoC05gn/VQy9O6SREINXO6sIRKRJMzPyMluQl9mC847qBEDpl1u/PpQ0fVkFbywuA6BFUjxH5mVUn2PIz+Dwjuk0S2j6w1VVBCISczq2TqFj6xTOHtARgLKN25n+dTGs4w9TlgDQLCGOAZ1aV+8xdMmgf27rJnmL7aAnrx8J3EP15PXj3f23eyzvCfwDGAD81N3v2t82dY5ARIJWsWVnjT2GdSxavRF3SIw3DozZ+LEAAAcdSURBVM9J54iOrejdIY0+Oel0y24ZFTfPC+UcgZnFA+OAEUApMMPMJrn7ohqrVQA3AGcFlUNE5EBltEhiZJ92jOzTDoAN23Yx8/Pqw0gzl3/JxKIStu7cDUBSQhy92qXSOyedPh3S6ZOTxmFtU6PqCuggDw0NBIrdfSmAmT0JnAl8XQTuXgaUmdmpAeYQETkk6cmJnNCzeqIdqB6yunzdFhas3MDCVRtZsHIDL89dxePTVwCQEDlh3Sey19AnJ41e7dNISWqcR+ODTJUDlNR4XgoMOpgNmdlVwFUAnTp1OvRkIiKHID7O6JrVkq5ZLTmzXw5QfaFbScU2FqzawIKVG1iwaiNvfVzG0zNLAYgz6JLV8uty6N0hnd45aY3i9ttBFkFtg3MP6oSEuz8IPAjV5wgOJZSISBDMjE5tUujUJoVvH94eqC6HLzZuZ8HKjZG9hw38Z2kFL8xZ9fX7OrdJoU+kFA6PFERGi6QGzR5kEZQCuTWedwRW7WVdEZEmx8xon55M+/RkRhS0/fr18k07WLjq/w4rzVu5nlfmr/56eU6r5K9PRvfJSaNPh3Sy04K7wV6QRTAD6G5m+cBK4HzgggC/n4hIVMhKbcbwHtkM75H99Wsbtu5i4aoNkUNLG1mwagOvL17DVwM7M1s2Y+xxXbhiWJd6zxNYEbh7pZldB0yhevjow+6+0MzGRpbfb2btgCIgDagys+8DBe6+MahcIiKNUXpKIkO6ZTKkW+bXr23eUcni1dV7DQtWbiQrtVkg31v3GhIRiQH7uo6g8V8FISIigVIRiIjEOBWBiEiMUxGIiMQ4FYGISIxTEYiIxDgVgYhIjFMRiIjEuKi7oMzM
yoHPw85xEDKBtWGHaGD6zE1frH1eiN7P3Nnds2pbEHVFEK3MrGhvV/U1VfrMTV+sfV5omp9Zh4ZERGKcikBEJMapCBrOg2EHCIE+c9MXa58XmuBn1jkCEZEYpz0CEZEYpyIQEYlxKoIAmVmumb1tZovNbKGZ3Rh2poZiZvFmNtvMXg47S0Mws1Zm9oyZfRz5+x4cdqagmdlNkX/XC8zsCTMLblLdkJjZw2ZWZmYLaryWYWavm9mnkT9bh5mxPqgIglUJ/MDdewFHA9eaWUHImRrKjcDisEM0oHuAye7eE+hLE//sZpYD3AAUunsfqqejPT/cVIGYAIzc47VbgTfdvTvwZuR5VFMRBMjdV7v7rMjXm6j+4ZATbqrgmVlH4FRgfNhZGoKZpQHHAn8HcPed7r4+3FQNIgFINrMEIAVYFXKeeufu7wEVe7x8JvBI5OtHgLMaNFQAVAQNxMzygP7A9HCTNIg/Az8CqsIO0kC6AOXAPyKHw8abWYuwQwXJ3VcCdwErgNXABnd/LdxUDaatu6+G6l/2gOyQ8xwyFUEDMLOWwLPA9919Y9h5gmRmpwFl7j4z7CwNKAEYAPzN3fsDW2gChwv2JXJc/EwgH+gAtDCz74abSg6WiiBgZpZIdQk85u7PhZ2nARwDnGFmy4EngRPM7NFwIwWuFCh196/29p6huhiaspOAZe5e7u67gOeAISFnaihrzKw9QOTPspDzHDIVQYDMzKg+brzY3f8Ydp6G4O63uXtHd8+j+uThW+7epH9TdPcvgBIz6xF56URgUYiRGsIK4GgzS4n8Oz+RJn6CvIZJwMWRry8GXgwxS71ICDtAE3cM8D1gvpnNibz2E3d/NcRMEozrgcfMLAlYClwacp5Auft0M3sGmEX16LjZNMVbL5g9AQwHMs2sFPgl8FtgopldTnUhjgovYf3QLSZERGKcDg2JiMQ4FYGISIxTEYiIxDgVgYhIjFMRiIjEOBWBSB2Z2Vk1bxpoZneY2UlhZhKpDxo+KlJHZjYBeNndnwk7i0h90h6BxDwzy4vMIfBQ5P76r5lZ8h7rDAHOAP5gZnPMrKuZTTCzcyPLl5vZ/zOzD82syMwGmNkUM/vMzMbW2M4tZjbDzOaZ2a8ir7Uws1fMbG7k3v7nNeTnF1ERiFTrDoxz997AeuCcmgvdfRrVtxa4xd37uftntWyjxN0HA1Opvo/9uVTPQ3EHgJl9K/J9BgL9gCPN7Fiq73e/yt37Ru7tPzmAzyeyV7rFhEi1Ze7+1W1AZgJ5B7GNSZE/5wMtI3NQbDKz7WbWCvhW5DE7sl5LqothKnCXmf2O6kNPUw/yM4gcFBWBSLUdNb7eDSTvbcU6bKNqj+1VUf1/zYA73f2BPd9oZkcC3wbuNLPX3P2Og/j+IgdFh4ZE6m4TkHoI758CXBaZnwIzyzGzbDPrAGx190epnuylqd/CWhoZ7RGI1N2TwENmdgPVx/8PiLu/Zma9gA+r79zMZuC7QDeqT0JXAbuAq+svssj+afioiEiM06EhEZEYpyIQEYlxKgIRkRinIhARiXEqAhGRGKciEBGJcSoCEZEY9/8BgGZMn+DXmHAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "val = [W]\n",
    "for i in range(10):\n",
    "    # Multiply by W once more: val[k] holds W^(k+1)\n",
    "    val.append(val[-1] @ W)\n",
    "# Compute the L2 norm of each matrix power\n",
    "norm = list(map(lambda x: tf.norm(x).numpy(), val))\n",
    "plt.plot(range(1, 12), norm)\n",
    "plt.xlabel('n times')\n",
    "plt.ylabel('L2-norm')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradient Clipping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Value clipping: clip the tensor's values directly, so that every element of the tensor $W$ satisfies $w_{ij} \\in [min, max]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=8044, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[0.58690715, 0.6       ],\n",
       "       [0.6       , 0.6       ]], dtype=float32)>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tf.random.uniform([2, 2])\n",
    "# Clip the gradient values element-wise into [0.4, 0.6]\n",
    "tf.clip_by_value(a, 0.4, 0.6)"
   ]
  },
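  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value clipping above can be sketched in plain NumPy. This only illustrates the formula and is not TensorFlow's implementation: every element is replaced by $\\min(\\max(w_{ij}, min), max)$, which is exactly what `np.clip` computes. The array and the names `v_min` and `v_max` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical gradient values, for illustration only\n",
    "g_demo = np.array([[0.1, 0.5], [0.7, 0.45]])\n",
    "v_min, v_max = 0.4, 0.6\n",
    "\n",
    "# Element-wise clipping written out as min(max(w, v_min), v_max)\n",
    "clipped_manual = np.minimum(np.maximum(g_demo, v_min), v_max)\n",
    "# NumPy's built-in equivalent\n",
    "clipped_np = np.clip(g_demo, v_min, v_max)\n",
    "\n",
    "print(clipped_manual)\n",
    "print(np.array_equal(clipped_manual, clipped_np))"
   ]
  },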
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
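    "2. Norm clipping: clip the gradient by limiting the L2 norm of the gradient tensor $W$. If $\\|W\\|_2 > max$, $W$ is rescaled to $W' = \\frac{W}{\\|W\\|_2} \\cdot max$. Otherwise it is left unchanged."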
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: id=8075, shape=(), dtype=float32, numpy=6.7737527>,\n",
       " <tf.Tensor: id=8080, shape=(), dtype=float32, numpy=5.0000005>)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tf.random.uniform([2, 2]) * 5\n",
    "# Clip by norm: rescale a so that its L2 norm is at most 5\n",
    "b = tf.clip_by_norm(a, 5)\n",
    "# L2 norm of the tensor before and after clipping\n",
    "tf.norm(a), tf.norm(b)"
   ]
  },
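  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal NumPy sketch of the norm-clipping rule $W' = W \\cdot \\frac{max}{\\max(\\|W\\|_2, max)}$. It assumes only NumPy, and `clip_by_norm_np` is a hypothetical helper written for illustration, not a library function. A tensor whose norm is already below the threshold comes back unchanged. A larger one is scaled down to norm $max$ without changing its direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def clip_by_norm_np(t, max_norm):\n",
    "    # L2 norm over all elements (what tf.norm computes with default arguments)\n",
    "    l2 = np.sqrt(np.sum(t ** 2))\n",
    "    # Unchanged when l2 <= max_norm; otherwise rescaled to have norm max_norm\n",
    "    return t * max_norm / max(l2, max_norm)\n",
    "\n",
    "big_demo = np.array([[3.0, 4.0], [0.0, 12.0]])    # L2 norm = 13\n",
    "small_demo = np.array([[0.3, 0.4], [0.0, 1.2]])   # L2 norm = 1.3\n",
    "\n",
    "print(np.linalg.norm(clip_by_norm_np(big_demo, 5.0)))          # approximately 5.0\n",
    "print(np.allclose(clip_by_norm_np(small_demo, 5.0), small_demo))  # True"
   ]
  },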
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
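    "3. Global norm clipping: treat all gradient tensors $W^{(1)}, \\dots, W^{(k)}$ as one group, compute their global norm $\\|W\\|_{global} = \\sqrt{\\sum_i \\|W^{(i)}\\|_2^2}$, and multiply every tensor by the same factor $\\frac{max}{\\max(\\|W\\|_{global}, max)}$. This preserves the direction of the overall gradient."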
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(3.5118904, shape=(), dtype=float32) tf.Tensor(1.9999998, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Create gradient tensor 1\n",
    "w1 = tf.random.normal([3, 3])\n",
    "# Create gradient tensor 2\n",
    "w2 = tf.random.normal([3, 3])\n",
    "# Compute the global norm manually\n",
    "global_norm = tf.math.sqrt(tf.norm(w1)**2 + tf.norm(w2)**2)\n",
    "# Clip with max_norm=2; the second return value is the global norm before clipping\n",
    "(ww1, ww2), global_norm = tf.clip_by_global_norm([w1, w2], 2)\n",
    "# Compute the global norm of the clipped tensors\n",
    "global_norm2 = tf.math.sqrt(tf.norm(ww1)**2 + tf.norm(ww2)**2)\n",
    "# Print the global norm before and after clipping\n",
    "print(global_norm, global_norm2)"
   ]
  },
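  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The global-norm rule can be sketched in NumPy as the same idea applied jointly to a list of tensors. The helper `clip_by_global_norm_np` is hypothetical. It mirrors the `(clipped_list, global_norm)` return shape used by `tf.clip_by_global_norm` above. Every tensor is multiplied by one shared factor, so the relative sizes of the gradients, and hence the overall update direction, are preserved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def clip_by_global_norm_np(tensors, max_norm):\n",
    "    # Global norm: square root of the sum of squared L2 norms of all tensors\n",
    "    g_norm = np.sqrt(sum(np.sum(t ** 2) for t in tensors))\n",
    "    # One shared scale factor; equals 1 when g_norm <= max_norm\n",
    "    scale = max_norm / max(g_norm, max_norm)\n",
    "    return [t * scale for t in tensors], g_norm\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "grads_demo = [rng.normal(size=(3, 3)), rng.normal(size=(3, 3))]\n",
    "clipped_demo, g_before = clip_by_global_norm_np(grads_demo, 2.0)\n",
    "g_after = np.sqrt(sum(np.sum(t ** 2) for t in clipped_demo))\n",
    "# The global norm after clipping equals min(g_before, 2.0)\n",
    "print(g_before, g_after)"
   ]
  },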
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using LSTM Layers\n",
    "\n",
    "### LSTMCell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(552212520, 552212520, 552240520)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = tf.random.normal([2, 80, 100])\n",
    "# Take the input at a single timestep: shape [2, 100]\n",
    "xt = x[:, 0, :]\n",
    "cell = layers.LSTMCell(64)  # Create an LSTM cell\n",
    "# Initialize the state list [h, c]\n",
    "state = [tf.zeros([2, 64]), tf.zeros([2, 64])]\n",
    "out, state = cell(xt, state)  # Forward pass for one timestep\n",
    "# Check the object ids: out and state[0] (h) are the same tensor\n",
    "id(out), id(state[0]), id(state[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
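    "# Reset the state, then unstack along the sequence-length dimension\n",
    "# and feed each timestep into the LSTM cell in turn\n",
    "state = [tf.zeros([2, 64]), tf.zeros([2, 64])]\n",
    "for xt in tf.unstack(x, axis=1):\n",
    "    # Forward pass for one timestep\n",
    "    out, state = cell(xt, state)"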
   ]
  },
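  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following from-scratch NumPy sketch shows roughly what `LSTMCell` computes at each timestep, using the standard LSTM equations. Three sigmoid gates $i_t, f_t, o_t$ and a tanh candidate $\\tilde{c}_t$ are each computed from $x_t W + h_{t-1} U + b$, then combined as $c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c}_t$ and $h_t = o_t \\odot \\tanh(c_t)$. The random weights and the packing order of the gate blocks ($i, f, \\tilde{c}, o$) are assumptions for illustration, so this does not reproduce Keras' numbers. As in the `id()` check earlier, the returned output is the same object as the new $h_t$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(v):\n",
    "    return 1.0 / (1.0 + np.exp(-v))\n",
    "\n",
    "def lstm_step(x_t, h_prev, c_prev, W, U, b):\n",
    "    # W: (input_dim, 4*units), U: (units, 4*units), b: (4*units,)\n",
    "    # Assumed packing order of the four blocks: i, f, c_tilde, o\n",
    "    z = x_t @ W + h_prev @ U + b\n",
    "    zi, zf, zc, zo = np.split(z, 4, axis=-1)\n",
    "    i = sigmoid(zi)        # input gate\n",
    "    f = sigmoid(zf)        # forget gate\n",
    "    c_tilde = np.tanh(zc)  # candidate memory\n",
    "    o = sigmoid(zo)        # output gate\n",
    "    c_t = f * c_prev + i * c_tilde\n",
    "    h_t = o * np.tanh(c_t)\n",
    "    # The output is h_t itself; the new state is [h_t, c_t]\n",
    "    return h_t, [h_t, c_t]\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "batch, steps, in_dim, units = 2, 80, 100, 64\n",
    "seq_demo = rng.normal(size=(batch, steps, in_dim))\n",
    "W_demo = rng.normal(scale=0.1, size=(in_dim, 4 * units))\n",
    "U_demo = rng.normal(scale=0.1, size=(units, 4 * units))\n",
    "b_demo = np.zeros(4 * units)\n",
    "\n",
    "h_demo = np.zeros((batch, units))\n",
    "c_demo = np.zeros((batch, units))\n",
    "for step in range(steps):\n",
    "    out_demo, (h_demo, c_demo) = lstm_step(\n",
    "        seq_demo[:, step, :], h_demo, c_demo, W_demo, U_demo, b_demo)\n",
    "print(out_demo.shape)      # (2, 64)\n",
    "print(out_demo is h_demo)  # True"
   ]
  },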
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LSTM Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
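    "# Create an LSTM layer with a state (memory) vector of length 64\n",
    "layer = layers.LSTM(64)\n",
    "# Pass the sequence through the layer; by default it returns only the output h at the last timestep\n",
    "out = layer(x)"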
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set return_sequences=True to return the output at every timestep\n",
    "layer = layers.LSTM(64, return_sequences=True)\n",
    "# Forward pass: the per-timestep outputs are stacked into a single tensor of shape [2, 80, 64]\n",
    "out = layer(x)"
   ]
  },
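  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A generic unrolling loop in NumPy can sketch how the two LSTM layers above relate. A plain tanh RNN cell with hypothetical random weights stands in for the LSTM cell here, as a simplification. With `return_sequences=True`, the per-timestep outputs are stacked along axis 1 into shape (batch, time, units). Without it, only the last output is returned, and it equals the final time slice of the stacked result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def run_rnn(seq, h0, step_fn, return_sequences=False):\n",
    "    # seq has shape (batch, time, features); unroll step_fn over axis 1\n",
    "    h_t = h0\n",
    "    outputs = []\n",
    "    for k in range(seq.shape[1]):\n",
    "        h_t = step_fn(seq[:, k, :], h_t)\n",
    "        outputs.append(h_t)\n",
    "    # True: stack all outputs -> (batch, time, units); False: last output -> (batch, units)\n",
    "    return np.stack(outputs, axis=1) if return_sequences else outputs[-1]\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "Wx_demo = rng.normal(scale=0.1, size=(100, 64))\n",
    "Wh_demo = rng.normal(scale=0.1, size=(64, 64))\n",
    "\n",
    "def tanh_cell(x_t, h_prev):\n",
    "    # A plain tanh RNN cell, used only as a stand-in for an LSTM cell\n",
    "    return np.tanh(x_t @ Wx_demo + h_prev @ Wh_demo)\n",
    "\n",
    "seq_demo2 = rng.normal(size=(2, 80, 100))\n",
    "h0_demo = np.zeros((2, 64))\n",
    "all_out = run_rnn(seq_demo2, h0_demo, tanh_cell, return_sequences=True)\n",
    "last_out = run_rnn(seq_demo2, h0_demo, tanh_cell)\n",
    "print(all_out.shape, last_out.shape)          # (2, 80, 64) (2, 64)\n",
    "print(np.allclose(all_out[:, -1], last_out))  # True"
   ]
  },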
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=10693, shape=(2, 64), dtype=float32, numpy=\n",
       "array([[-7.29653612e-02,  1.15865074e-01, -7.15253651e-02,\n",
       "         7.39371553e-02,  4.68634143e-02,  5.46397939e-02,\n",
       "         3.66614596e-03, -3.41411717e-02, -9.88797173e-02,\n",
       "        -7.74470195e-02,  9.54263564e-03,  6.62799105e-02,\n",
       "         5.46958186e-02, -9.72819999e-02, -5.92009760e-02,\n",
       "         8.12323615e-02, -2.34329626e-02, -7.07291812e-02,\n",
       "         2.66989227e-02,  4.51712869e-02,  1.57007724e-01,\n",
       "        -7.51742488e-03,  1.35744780e-01,  1.02491170e-01,\n",
       "         4.77755927e-02, -7.14057460e-02, -5.32417037e-02,\n",
       "         3.11367679e-03,  1.98544171e-02,  1.20310985e-01,\n",
       "        -3.45447473e-02,  9.80208740e-02, -2.00126737e-01,\n",
       "         3.35342549e-02,  1.09030381e-02, -8.74172971e-02,\n",
       "         1.21381782e-01, -1.21675864e-01, -4.28368524e-02,\n",
       "        -1.63171720e-03,  7.50030763e-03,  6.19905367e-02,\n",
       "         6.91571087e-02, -1.79095380e-02,  2.11950559e-02,\n",
       "        -2.99647055e-03,  2.53018010e-02, -1.27889946e-01,\n",
       "         2.90306136e-02,  2.04134416e-02, -5.84693067e-02,\n",
       "        -8.47572373e-05,  2.86764801e-02, -1.46667600e-01,\n",
       "         4.19261158e-02, -8.66340473e-02,  5.49870357e-02,\n",
       "        -1.69840395e-01,  9.46776941e-02, -7.93699250e-02,\n",
       "        -8.68466422e-02,  1.85655400e-01, -6.72442019e-02,\n",
       "        -3.39936800e-02],\n",
       "       [ 1.35182619e-01,  1.89286768e-01,  1.39465556e-01,\n",
       "        -1.31353527e-01, -5.33331931e-02,  1.66756541e-01,\n",
       "        -9.98248756e-02,  1.16001211e-01,  2.07160320e-02,\n",
       "         1.05367722e-02,  9.57108960e-02,  5.47952540e-02,\n",
       "         9.50262770e-02,  7.72577338e-03,  3.10489093e-03,\n",
       "        -1.58313308e-02, -4.08570543e-02,  8.74623656e-02,\n",
       "        -1.00237563e-01, -4.59164865e-02,  3.95013466e-02,\n",
       "        -9.21263173e-02, -6.73963130e-02,  3.45224701e-02,\n",
       "        -1.24158934e-01,  8.55000094e-02, -5.82238659e-02,\n",
       "        -9.40350741e-02,  7.63973547e-03, -2.28851154e-01,\n",
       "        -4.19710018e-02,  1.56359170e-02, -2.00016471e-03,\n",
       "        -7.04211891e-02, -1.29387090e-02, -1.37450963e-01,\n",
       "         6.81063533e-02, -1.01319902e-01, -5.00466265e-02,\n",
       "        -1.31414577e-01,  9.37328711e-02, -6.82587847e-02,\n",
       "         7.09251016e-02,  5.85846305e-02, -5.06423190e-02,\n",
       "        -9.98194739e-02,  1.12542873e-02,  1.20005630e-01,\n",
       "         1.95571512e-01,  8.41921270e-02,  3.75407264e-02,\n",
       "         7.91881159e-02,  5.79719506e-02, -1.26136616e-01,\n",
       "        -1.39808372e-01,  1.38167590e-01, -1.76716790e-01,\n",
       "        -5.94828129e-02, -9.41112265e-03,  4.15128320e-02,\n",
       "        -5.21374717e-02,  4.28842120e-02,  4.89364192e-02,\n",
       "         7.69282505e-02]], dtype=float32)>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
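    "# As with CNNs, LSTM layers can simply be stacked\n",
    "net = keras.Sequential([\n",
    "    layers.LSTM(64, return_sequences=True),  # Non-final layers must return the outputs of all timesteps\n",
    "    layers.LSTM(64)\n",
    "])\n",
    "# A single pass through the model yields the last layer's output at the final timestep\n",
    "out = net(x)\n",
    "out"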
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([2, 64])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialize the state vector; GRU has only one (h)\n",
    "h = [tf.zeros([2, 64])]\n",
    "# Create a GRU cell with a state vector of length 64\n",
    "cell = layers.GRUCell(64)\n",
    "# Unstack along the time dimension and loop through the cell\n",
    "for xt in tf.unstack(x, axis=1):\n",
    "    out, h = cell(xt, h)\n",
    "# Output shape\n",
    "out.shape"
   ]
  },
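  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A NumPy sketch of a single GRU step helps explain why one state vector is enough. The update gate $z_t$ and the reset gate $r_t$ act directly on $h$: $\\tilde{h}_t = \\tanh(x_t W_h + (r_t \\odot h_{t-1}) U_h + b_h)$ and $h_t = z_t \\odot h_{t-1} + (1 - z_t) \\odot \\tilde{h}_t$. This is the variant that applies the reset gate before the recurrent matrix multiply. TF2's `layers.GRUCell` defaults to `reset_after=True`, so this sketch, which also uses hypothetical random weights, is not numerically identical to Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(v):\n",
    "    return 1.0 / (1.0 + np.exp(-v))\n",
    "\n",
    "def gru_step(x_t, h_prev, params):\n",
    "    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params\n",
    "    z = sigmoid(x_t @ Wz + h_prev @ Uz + bz)              # update gate\n",
    "    r = sigmoid(x_t @ Wr + h_prev @ Ur + br)              # reset gate\n",
    "    h_tilde = np.tanh(x_t @ Wh + (r * h_prev) @ Uh + bh)  # candidate state\n",
    "    # The new h_t is both the output and the only state\n",
    "    return z * h_prev + (1.0 - z) * h_tilde\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "gru_in_dim, gru_units = 100, 64\n",
    "\n",
    "def init_gate():\n",
    "    # Hypothetical random weights for one gate: (W, U, b)\n",
    "    return (rng.normal(scale=0.1, size=(gru_in_dim, gru_units)),\n",
    "            rng.normal(scale=0.1, size=(gru_units, gru_units)),\n",
    "            np.zeros(gru_units))\n",
    "\n",
    "# Parameters packed in the order: update gate, reset gate, candidate\n",
    "gru_params = init_gate() + init_gate() + init_gate()\n",
    "\n",
    "seq_demo3 = rng.normal(size=(2, 80, gru_in_dim))\n",
    "h_gru = np.zeros((2, gru_units))\n",
    "for step in range(seq_demo3.shape[1]):\n",
    "    h_gru = gru_step(seq_demo3[:, step, :], h_gru, gru_params)\n",
    "print(h_gru.shape)  # (2, 64)"
   ]
  },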
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `layers.GRU` class makes it easy to create a GRU layer, and the `Sequential` container can stack several GRU layers into one network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack two GRU layers; the non-final layer must return the outputs of all timesteps\n",
    "net = keras.Sequential([\n",
    "    layers.GRU(64, return_sequences=True),\n",
    "    layers.GRU(64)\n",
    "])\n",
    "out = net(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
